{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parameter Learning in Discrete Bayesian Networks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates how to learn the parameters (CPDs) of a discrete Bayesian network from data, given the model structure. pgmpy provides three main algorithms for learning model parameters:\n",
    "\n",
    "1. **Maximum Likelihood Estimator** (`pgmpy.estimators.MaximumLikelihoodEstimator`): Computes the maximum likelihood estimates of the parameters.\n",
    "2. **Bayesian Estimator** (`pgmpy.estimators.BayesianEstimator`): Estimates the parameters while allowing users to specify priors.\n",
    "3. **Expectation Maximization** (`pgmpy.estimators.ExpectationMaximization`): Enables learning model parameters when latent variables are present in the model.\n",
    "\n",
    "Each of the parameter estimation classes has the following two methods:\n",
    "\n",
    "1. `estimate_cpd`: Estimates the CPD of the specified variable.\n",
    "2. `get_parameters`: Estimates the CPDs of all the variables in the model.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 0: Generate some simulated data and a model structure\n",
    "\n",
    "To do parameter estimation we require two things:\n",
    "1. **Data**: For the examples, we simulate some data from the alarm model (https://www.bnlearn.com/bnrepository/discrete-medium.html#alarm) and use it to recover the model parameters.\n",
    "2. **Model structure**: We also need to specify the model structure to fit to the data. In this example, we simply reuse the structure of the alarm model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "07163638093f43efa83db9a97b16347b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/37 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -1.4901161193847656e-08. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1175870895385742e-08. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1175870895385742e-08. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -9.313225746154785e-09. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1175870895385742e-08. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -9.313225746154785e-09. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -1.6763806343078613e-08. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n",
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   EXPCO2  INTUBATION    PCWP   HREKG VENTLUNG  SAO2 VENTALV PULMEMBOLUS  \\\n",
      "0     LOW      NORMAL  NORMAL    HIGH     ZERO   LOW    ZERO       FALSE   \n",
      "1     LOW      NORMAL  NORMAL  NORMAL     ZERO   LOW    ZERO       FALSE   \n",
      "2     LOW      NORMAL    HIGH    HIGH      LOW  HIGH    HIGH       FALSE   \n",
      "3  NORMAL    ONESIDED  NORMAL    HIGH      LOW   LOW     LOW       FALSE   \n",
      "4     LOW  ESOPHAGEAL     LOW    HIGH     ZERO   LOW     LOW       FALSE   \n",
      "\n",
      "  ERRLOWOUTPUT      HR  ... LVFAILURE KINKEDTUBE HISTORY HYPOVOLEMIA  \\\n",
      "0        FALSE  NORMAL  ...     FALSE      FALSE   FALSE       FALSE   \n",
      "1        FALSE    HIGH  ...     FALSE      FALSE   FALSE       FALSE   \n",
      "2         TRUE    HIGH  ...     FALSE      FALSE   FALSE        TRUE   \n",
      "3        FALSE    HIGH  ...     FALSE      FALSE   FALSE       FALSE   \n",
      "4        FALSE    HIGH  ...     FALSE      FALSE   FALSE       FALSE   \n",
      "\n",
      "  STROKEVOLUME VENTMACH VENTTUBE     CVP   SHUNT MINVOLSET  \n",
      "0       NORMAL   NORMAL      LOW  NORMAL  NORMAL    NORMAL  \n",
      "1       NORMAL      LOW     ZERO  NORMAL  NORMAL    NORMAL  \n",
      "2       NORMAL     HIGH     HIGH    HIGH  NORMAL    NORMAL  \n",
      "3       NORMAL   NORMAL      LOW  NORMAL    HIGH    NORMAL  \n",
      "4       NORMAL   NORMAL     ZERO     LOW  NORMAL    NORMAL  \n",
      "\n",
      "[5 rows x 37 columns]\n"
     ]
    }
   ],
   "source": [
    "from pgmpy.utils import get_example_model\n",
    "from pgmpy.models import DiscreteBayesianNetwork\n",
    "\n",
    "# Load the alarm model and simulate data from it.\n",
    "alarm_model = get_example_model(model=\"alarm\")\n",
    "samples = alarm_model.simulate(n_samples=int(1e3))\n",
    "\n",
    "print(samples.head())\n",
    "\n",
    "# Define a new model with the same structure as the alarm model.\n",
    "new_model = DiscreteBayesianNetwork(ebunch=alarm_model.edges())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using the Maximum Likelihood Estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pgmpy.estimators import MaximumLikelihoodEstimator\n",
    "\n",
    "# Initialize the estimator object.\n",
    "mle_est = MaximumLikelihoodEstimator(model=new_model, data=samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------+-------+\n",
      "| FIO2(LOW)    | 0.058 |\n",
      "+--------------+-------+\n",
      "| FIO2(NORMAL) | 0.942 |\n",
      "+--------------+-------+\n"
     ]
    }
   ],
   "source": [
    "# Estimate the CPD of the node FIO2.\n",
    "print(mle_est.estimate_cpd(node=\"FIO2\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<TabularCPD representing P(CVP:3 | LVEDVOLUME:3) at 0x7b043ec68860>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Estimate the CPD of node CVP\n",
    "mle_est.estimate_cpd(node=\"CVP\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<TabularCPD representing P(PCWP:3 | LVEDVOLUME:3) at 0x7b0427cb85f0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Estimate all the CPDs for `new_model`\n",
    "all_cpds = mle_est.get_parameters(n_jobs=1)\n",
    "\n",
    "# Add the estimated CPDs to the model.\n",
    "new_model.add_cpds(*all_cpds)\n",
    "\n",
    "# Check if the CPDs are added to the model\n",
    "new_model.get_cpds('PCWP')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using the Bayesian Estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the Bayesian Estimator\n",
    "from pgmpy.estimators import BayesianEstimator\n",
    "\n",
    "be_est = BayesianEstimator(model=new_model, data=samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The estimator methods of the `BayesianEstimator` class allow the priors to be specified in a few different ways, selected through the `prior_type` argument. Please refer to the documentation for details on how these priors can be specified: https://pgmpy.org/param_estimator/bayesian_est.html#bayesian-estimator\n",
    "\n",
    "1. **Dirichlet prior** (`prior_type=\"dirichlet\"`): Requires the `pseudo_counts` argument, which specifies the pseudo-counts (the prior) to use for CPD estimation.\n",
    "2. **BDeu prior** (`prior_type=\"BDeu\"`): Requires the `equivalent_sample_size` argument, which is used to compute the priors for CPD estimation.\n",
    "3. **K2** (`prior_type=\"K2\"`): Shorthand for a Dirichlet prior with all pseudo-counts set to 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------+-------+\n",
      "| FIO2(LOW)    | 0.279 |\n",
      "+--------------+-------+\n",
      "| FIO2(NORMAL) | 0.721 |\n",
      "+--------------+-------+\n",
      "+-------------+---------------------+-----+---------------------+\n",
      "| LVEDVOLUME  | LVEDVOLUME(HIGH)    | ... | LVEDVOLUME(NORMAL)  |\n",
      "+-------------+---------------------+-----+---------------------+\n",
      "| CVP(HIGH)   | 0.48841698841698844 | ... | 0.10685483870967742 |\n",
      "+-------------+---------------------+-----+---------------------+\n",
      "| CVP(LOW)    | 0.19884169884169883 | ... | 0.13709677419354838 |\n",
      "+-------------+---------------------+-----+---------------------+\n",
      "| CVP(NORMAL) | 0.3127413127413127  | ... | 0.7560483870967742  |\n",
      "+-------------+---------------------+-----+---------------------+\n"
     ]
    }
   ],
   "source": [
    "print(be_est.estimate_cpd(node=\"FIO2\", prior_type=\"BDeu\", equivalent_sample_size=1000))\n",
    "print(be_est.estimate_cpd(node=\"CVP\", prior_type=\"dirichlet\", pseudo_counts=100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<TabularCPD representing P(HYPOVOLEMIA:2) at 0x7b050bad4590>,\n",
       " <TabularCPD representing P(LVEDVOLUME:3 | HYPOVOLEMIA:2, LVFAILURE:2) at 0x7b0427cb9c40>,\n",
       " <TabularCPD representing P(STROKEVOLUME:3 | HYPOVOLEMIA:2, LVFAILURE:2) at 0x7b0427cb9250>,\n",
       " <TabularCPD representing P(CVP:3 | LVEDVOLUME:3) at 0x7b0427cb9070>,\n",
       " <TabularCPD representing P(PCWP:3 | LVEDVOLUME:3) at 0x7b0427cbb080>,\n",
       " <TabularCPD representing P(LVFAILURE:2) at 0x7b0427cbb110>,\n",
       " <TabularCPD representing P(HISTORY:2 | LVFAILURE:2) at 0x7b0427cbaf90>,\n",
       " <TabularCPD representing P(CO:3 | HR:3, STROKEVOLUME:3) at 0x7b0427cbaa20>,\n",
       " <TabularCPD representing P(ERRLOWOUTPUT:2) at 0x7b0427cbb020>,\n",
       " <TabularCPD representing P(HRBP:3 | ERRLOWOUTPUT:2, HR:3) at 0x7b0427cbb2c0>,\n",
       " <TabularCPD representing P(ERRCAUTER:2) at 0x7b0427cbb4d0>,\n",
       " <TabularCPD representing P(HREKG:3 | ERRCAUTER:2, HR:3) at 0x7b0427cbb1a0>,\n",
       " <TabularCPD representing P(HRSAT:3 | ERRCAUTER:2, HR:3) at 0x7b0427cbac00>,\n",
       " <TabularCPD representing P(INSUFFANESTH:2) at 0x7b0427cba3c0>,\n",
       " <TabularCPD representing P(CATECHOL:2 | ARTCO2:3, INSUFFANESTH:2, SAO2:3, TPR:3) at 0x7b0427cbb170>,\n",
       " <TabularCPD representing P(ANAPHYLAXIS:2) at 0x7b0427cbaea0>,\n",
       " <TabularCPD representing P(TPR:3 | ANAPHYLAXIS:2) at 0x7b0427cbb200>,\n",
       " <TabularCPD representing P(BP:3 | CO:3, TPR:3) at 0x7b0427cbb0b0>,\n",
       " <TabularCPD representing P(KINKEDTUBE:2) at 0x7b0427cba1b0>,\n",
       " <TabularCPD representing P(PRESS:4 | INTUBATION:3, KINKEDTUBE:2, VENTTUBE:4) at 0x7b0427cb98e0>,\n",
       " <TabularCPD representing P(VENTLUNG:4 | INTUBATION:3, KINKEDTUBE:2, VENTTUBE:4) at 0x7b0427cbb350>,\n",
       " <TabularCPD representing P(FIO2:2) at 0x7b0427cbade0>,\n",
       " <TabularCPD representing P(PVSAT:3 | FIO2:2, VENTALV:4) at 0x7b0427cba900>,\n",
       " <TabularCPD representing P(SAO2:3 | PVSAT:3, SHUNT:2) at 0x7b0427cb8d70>,\n",
       " <TabularCPD representing P(PULMEMBOLUS:2) at 0x7b0427cbb2f0>,\n",
       " <TabularCPD representing P(PAP:3 | PULMEMBOLUS:2) at 0x7b0427cb9790>,\n",
       " <TabularCPD representing P(SHUNT:2 | INTUBATION:3, PULMEMBOLUS:2) at 0x7b0427cbb3e0>,\n",
       " <TabularCPD representing P(INTUBATION:3) at 0x7b0427cbb4a0>,\n",
       " <TabularCPD representing P(MINVOL:4 | INTUBATION:3, VENTLUNG:4) at 0x7b0427cbb530>,\n",
       " <TabularCPD representing P(VENTALV:4 | INTUBATION:3, VENTLUNG:4) at 0x7b0427cbb560>,\n",
       " <TabularCPD representing P(DISCONNECT:2) at 0x7b0427cbb590>,\n",
       " <TabularCPD representing P(VENTTUBE:4 | DISCONNECT:2, VENTMACH:4) at 0x7b0427cbb5c0>,\n",
       " <TabularCPD representing P(MINVOLSET:3) at 0x7b0427cbb5f0>,\n",
       " <TabularCPD representing P(VENTMACH:4 | MINVOLSET:3) at 0x7b0427cbb620>,\n",
       " <TabularCPD representing P(EXPCO2:4 | ARTCO2:3, VENTLUNG:4) at 0x7b0427cbb650>,\n",
       " <TabularCPD representing P(ARTCO2:3 | VENTALV:4) at 0x7b0427cbb680>,\n",
       " <TabularCPD representing P(HR:3 | CATECHOL:2) at 0x7b0427cbb6b0>]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# With the K2 prior, `equivalent_sample_size` is not needed.\n",
    "be_est.get_parameters(prior_type=\"K2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Expectation Maximization"
   ]
  },
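  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Expectation Maximization (EM) estimator can learn parameters when latent variables are present in the model. To simulate this scenario, we define a new model, `model_latent`, with the alarm structure and some of its variables marked as latent, and drop the corresponding columns from `samples` so that these variables are never observed. Because the states of a latent variable cannot be read from the data, EM has to assume a cardinality for it; in the output below, `CVP` is learned with two states instead of the three it has in the original alarm model."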
   ]
  },
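  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mark HISTORY and CVP as latent variables in the model structure.\n",
    "model_latent = DiscreteBayesianNetwork(alarm_model.edges(), latents={'HISTORY', 'CVP'})\n",
    "# Drop the corresponding columns so that these variables are unobserved in the data.\n",
    "samples_latent = samples.drop(['HISTORY', 'CVP'], axis=1)"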
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9ce6966d366a427798a81f190072b282",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[<TabularCPD representing P(EXPCO2:4 | ARTCO2:3, VENTLUNG:4) at 0x7b0427e26930>,\n",
       " <TabularCPD representing P(INTUBATION:3) at 0x7b0427cbbe90>,\n",
       " <TabularCPD representing P(PCWP:3 | LVEDVOLUME:3) at 0x7b0427e27590>,\n",
       " <TabularCPD representing P(HREKG:3 | ERRCAUTER:2, HR:3) at 0x7b0427d01fd0>,\n",
       " <TabularCPD representing P(VENTLUNG:4 | INTUBATION:3, KINKEDTUBE:2, VENTTUBE:4) at 0x7b0427cb9eb0>,\n",
       " <TabularCPD representing P(SAO2:3 | PVSAT:3, SHUNT:2) at 0x7b0427e27d10>,\n",
       " <TabularCPD representing P(VENTALV:4 | INTUBATION:3, VENTLUNG:4) at 0x7b0427d019a0>,\n",
       " <TabularCPD representing P(PULMEMBOLUS:2) at 0x7b0427e260f0>,\n",
       " <TabularCPD representing P(ERRLOWOUTPUT:2) at 0x7b0427d01910>,\n",
       " <TabularCPD representing P(HR:3 | CATECHOL:2) at 0x7b0427d01bb0>,\n",
       " <TabularCPD representing P(HRSAT:3 | ERRCAUTER:2, HR:3) at 0x7b0427d02480>,\n",
       " <TabularCPD representing P(DISCONNECT:2) at 0x7b0427ca7290>,\n",
       " <TabularCPD representing P(ERRCAUTER:2) at 0x7b0427d01f40>,\n",
       " <TabularCPD representing P(CO:3 | HR:3, STROKEVOLUME:3) at 0x7b0427cba150>,\n",
       " <TabularCPD representing P(BP:3 | CO:3, TPR:3) at 0x7b0427d02720>,\n",
       " <TabularCPD representing P(LVEDVOLUME:3 | HYPOVOLEMIA:2, LVFAILURE:2) at 0x7b0427d02870>,\n",
       " <TabularCPD representing P(ANAPHYLAXIS:2) at 0x7b0427ca6c90>,\n",
       " <TabularCPD representing P(TPR:3 | ANAPHYLAXIS:2) at 0x7b0427d025a0>,\n",
       " <TabularCPD representing P(HRBP:3 | ERRLOWOUTPUT:2, HR:3) at 0x7b0427cba0c0>,\n",
       " <TabularCPD representing P(PVSAT:3 | FIO2:2, VENTALV:4) at 0x7b0427d02b10>,\n",
       " <TabularCPD representing P(CATECHOL:2 | ARTCO2:3, INSUFFANESTH:2, SAO2:3, TPR:3) at 0x7b0427cb96d0>,\n",
       " <TabularCPD representing P(INSUFFANESTH:2) at 0x7b0427cba8d0>,\n",
       " <TabularCPD representing P(FIO2:2) at 0x7b0427d02390>,\n",
       " <TabularCPD representing P(ARTCO2:3 | VENTALV:4) at 0x7b0427d02ab0>,\n",
       " <TabularCPD representing P(PAP:3 | PULMEMBOLUS:2) at 0x7b0427d01af0>,\n",
       " <TabularCPD representing P(MINVOL:4 | INTUBATION:3, VENTLUNG:4) at 0x7b0427d02fc0>,\n",
       " <TabularCPD representing P(PRESS:4 | INTUBATION:3, KINKEDTUBE:2, VENTTUBE:4) at 0x7b0427cbadb0>,\n",
       " <TabularCPD representing P(LVFAILURE:2) at 0x7b044b06f4a0>,\n",
       " <TabularCPD representing P(KINKEDTUBE:2) at 0x7b0427d02e10>,\n",
       " <TabularCPD representing P(HYPOVOLEMIA:2) at 0x7b0427d03170>,\n",
       " <TabularCPD representing P(STROKEVOLUME:3 | HYPOVOLEMIA:2, LVFAILURE:2) at 0x7b0427cbafc0>,\n",
       " <TabularCPD representing P(VENTMACH:4 | MINVOLSET:3) at 0x7b0427d02f30>,\n",
       " <TabularCPD representing P(VENTTUBE:4 | DISCONNECT:2, VENTMACH:4) at 0x7b0427d01ca0>,\n",
       " <TabularCPD representing P(SHUNT:2 | INTUBATION:3, PULMEMBOLUS:2) at 0x7b0427ca6e10>,\n",
       " <TabularCPD representing P(MINVOLSET:3) at 0x7b0427e256d0>,\n",
       " <TabularCPD representing P(CVP:2 | LVEDVOLUME:3) at 0x7b043ef333b0>,\n",
       " <TabularCPD representing P(HISTORY:2 | LVFAILURE:2) at 0x7b043ec68b60>]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pgmpy.estimators import ExpectationMaximization as EM\n",
    "\n",
    "# Initialize the EM estimator with the latent-variable model and the reduced data.\n",
    "em_est = EM(model=model_latent, data=samples_latent)\n",
    "em_est.get_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shortcut for learning and adding CPDs to the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `DiscreteBayesianNetwork` class also provides a `fit` method as a shortcut for estimating the CPDs and adding them to the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------+-------+\n",
      "| FIO2(LOW)    | 0.058 |\n",
      "+--------------+-------+\n",
      "| FIO2(NORMAL) | 0.942 |\n",
      "+--------------+-------+\n",
      "+--------------+-------+\n",
      "| FIO2(LOW)    | 0.279 |\n",
      "+--------------+-------+\n",
      "| FIO2(NORMAL) | 0.721 |\n",
      "+--------------+-------+\n"
     ]
    }
   ],
   "source": [
    "# Shortcut for learning all the parameters and adding the CPDs to the model.\n",
    "\n",
    "model_struct = DiscreteBayesianNetwork(ebunch=alarm_model.edges())\n",
    "model_struct.fit(data=samples, estimator=MaximumLikelihoodEstimator)\n",
    "print(model_struct.get_cpds(\"FIO2\"))\n",
    "\n",
    "model_struct = DiscreteBayesianNetwork(ebunch=alarm_model.edges())\n",
    "model_struct.fit(\n",
    "    data=samples,\n",
    "    estimator=BayesianEstimator,\n",
    "    prior_type=\"BDeu\",\n",
    "    equivalent_sample_size=1000,\n",
    ")\n",
    "print(model_struct.get_cpds(\"FIO2\"))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
